import logging
from typing import Any

from pydantic import BaseModel, Field, PrivateAttr, validate_call

import weave
from weave.flow.scorer import WeaveScorerResult
from weave.scorers.default_models import OPENAI_DEFAULT_MODEL
from weave.scorers.scorer_types import HuggingFacePipelineScorer, LLMScorer
from weave.scorers.utils import MODEL_PATHS, load_local_model_weights, stringify

logger = logging.getLogger(__name__)


# System prompt for LLM-based hallucination detection.
# NOTE: the example JSON below must mirror the `HallucinationResponse` schema;
# the list field is named `reasonings`, so the in-context example uses that key
# (the original example said "reasoning", contradicting the schema enforced via
# `response_format`).
DEFAULT_HALLUCINATION_SYSTEM_PROMPT = """
Given some <input_data> from a user and an <output> generated by an AI system, \
determine if the <output> contains any hallucinations.

A "hallucination" is defined as information in the <output> that is not supported by \
the <input_data> or is not factually or logically consistent with the <input_data>.

# Steps
1. Carefully read and understand the input data.
2. Examine the model output.
3. Compare the output to the input data, identifying any inconsistencies or additions.
4. Evaluate the logical connection between input and output.
5. Determine if any information in the output is not supported by or conflicts with the input.

# Guidelines
- Focus on factual accuracy and logical consistency
- Consider both explicit and implicit information in the input data
- Be aware of potential misinterpretations or over-generalizations in the output
- Identify any information in the output that goes beyond the scope of the input

# Examples
## Data to analyze

<input_data_example>
The cat is black and white.
</input_data_example>

<output_example>
The cat has orange stripes.
</output_example>

## Analysis:
{
  "chain_of_thought": "The cat is black and white. The cat has orange stripes. \
The output contradicts the input data because the input specifies black and white, \
while the output mentions orange. The output also introduces a pattern not present in \
the input.",
  "reasonings": [
    {
      "hallucination_type": "Color comparison",
      "observation": "Input specifies black and white, output mentions orange"
    },
    {
      "hallucination_type": "Pattern analysis",
      "observation": "Input doesn't mention any pattern, output introduces stripes"
    }
  ],
  "conclusion": "The output contains two hallucinations: it contradicts the color information \
and introduces a pattern not present in the input.",
  "has_hallucination": true
}

# Notes
- Ensure each step in the reasoning process is clearly articulated
- Be objective and avoid assumptions not supported by the input data
- If the output contains factual information not present in the input, it may be a \
hallucination even if it doesn't directly contradict the input
"""

# User prompt template; filled via `str.format` with `input_data` and `output`
# in `HallucinationFreeScorer.score`. Both placeholders are required.
DEFAULT_HALLUCINATION_USER_PROMPT = """
Analyze the following <input_data> and <output> and determine if the <output> contains any hallucinations.
# Data to analyze

<input_data>
{input_data}
</input_data>

<output>
{output}
</output>
"""


class HallucinationReasoning(BaseModel):
    """A single piece of evidence supporting (or refuting) a hallucination verdict."""

    hallucination_type: str = Field(
        description="A short name for the type of hallucination."
    )
    observation: str = Field(
        description="An observation from the <input_data> and <output> that supports the hallucination."
    )


class HallucinationResponse(BaseModel):
    """Structured output schema the LLM is asked to produce in `HallucinationFreeScorer.score`.

    Passed as `response_format` to the completion call; the provider's JSON reply is
    validated back into this model.
    """

    chain_of_thought: str = Field(
        description="Think step by step about whether the <output> contains hallucinations based on the <input_data>.",
    )
    reasonings: list[HallucinationReasoning] = Field(
        description="A list of reasoning steps that lead to the conclusion about whether or not the <output> contains hallucinations.",
    )
    conclusion: str = Field(description="The conclusion of the analysis.")
    has_hallucination: bool = Field(
        description="Indicates whether the <output> contains hallucinations based on the <input_data>. True means hallucinations are present."
    )


class HallucinationFreeScorer(LLMScorer):
    """A Scorer that uses an LLM to determine if the model output contains any hallucinations
    based on the input data.

    Note:
        - The meaning of "hallucination" can vary between users. \
          You may want to customize the `system_prompt` and `user_prompt` to suit your specific needs.
        - The scorer utilizes the `litellm.acompletion` function to generate structured outputs \
          from the LLM provider's response.
        - The `score` method expects the input column from the dataset to be named "context". \
          It will use this data as the ground-truth to check hallucinations against. \
          If your dataset's column has a different name, you can specify a different mapping using the \
          `column_map` argument in the __init__ of HallucinationFreeScorer, for example: `column_map={"context": "news_article"}`.

    Attributes:
        system_prompt (str): The prompt describing the task and defining what a "hallucination" is.
        user_prompt (str): The string template to pass the input and output data. The template must \
        contain placeholders for both `{input_data}` and `{output}`.
        model_id (str): The LLM model name, dependent on the LLM provider being used.
        temperature (float): Controls randomness in the LLM's responses (0.0 to 1.0)
        max_tokens (int): Maximum number of tokens allowed in the LLM's response

    Methods:
        score(output: str, context: str) -> dict:
            Analyzes the output to detect hallucinations based on the given context and
            returns the `HallucinationResponse` fields as a plain dict.
    """

    system_prompt: str = DEFAULT_HALLUCINATION_SYSTEM_PROMPT
    user_prompt: str = DEFAULT_HALLUCINATION_USER_PROMPT
    model_id: str = OPENAI_DEFAULT_MODEL
    temperature: float = Field(
        default=0.7,
        description="Controls randomness in the LLM's responses (0.0 to 1.0)",
    )
    max_tokens: int = Field(
        default=4096,
        description="Maximum number of tokens allowed in the LLM's response",
    )

    @weave.op
    async def score(self, *, output: str, context: str, **kwargs: Any) -> dict:
        """Check `output` against `context` for hallucinations via one LLM call.

        Returns the validated `HallucinationResponse` as a plain dict
        (`model_dump()`). Extra keyword arguments are accepted and ignored.
        """
        # Coerce arbitrary model output (dicts, lists, ...) into a string for prompting.
        output = stringify(output)
        response = await self._acompletion(
            messages=[
                {"role": "system", "content": self.system_prompt},
                {
                    "role": "user",
                    "content": self.user_prompt.format(
                        input_data=context, output=output
                    ),
                },
            ],
            model=self.model_id,
            # Constrain the provider to emit JSON matching HallucinationResponse.
            response_format=HallucinationResponse,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )
        # Parse the provider's JSON text back into the schema, then return a
        # plain dict for downstream aggregation.
        response = HallucinationResponse.model_validate_json(
            response.choices[0].message.content
        )
        return response.model_dump()


HALLUCINATION_SCORER_THRESHOLD = 0.35


class WeaveHallucinationScorerV1(HuggingFacePipelineScorer):
    """A scorer that detects hallucinations in the output, given a query and context. This scorer
    uses the HHEM 2.1 model from Vectara, https://huggingface.co/vectara/hallucination_evaluation_model.

    This scorer uses a fine-tuned LLM to analyze whether model outputs contain information not supported
    by the given context.

    Args:
        device: Device to run model on, defaults to "cuda"
        model_name_or_path: Path or name of model weights to load

    Note: This Scorer's `score` method will score the text passed to its `output` parameter for hallucinations
    based on the query and context.

    Returns:
        WeaveScorerResult: Contains:
            - passed (bool): True if no hallucinations detected (score strictly < threshold)
            - metadata (dict): Contains:
                - score (float): Hallucination score between 0 and 1, where higher is more hallucination

    Example:
        >>> scorer = WeaveHallucinationScorerV1()
        >>> result = scorer.score(
        ...     query="What is the capital of France?",
        ...     context="Paris is the capital of France.",
        ...     output="Paris is the capital of France."
        ... )
        >>> print(result)
        WeaveScorerResult(passed=True, metadata={'score': 0.1})
    """

    threshold: float = Field(
        default=HALLUCINATION_SCORER_THRESHOLD,
        description="The threshold for the hallucination scorer.",
    )
    task: str = Field(
        default="pair-classification",
        description="The HF task name to use for the hallucination scorer pipeline.",
    )
    # Maximum combined token budget for (query + context, output); inputs beyond
    # this are truncated in `_predict`.
    _model_max_length: int = PrivateAttr(default=8192)

    def load_pipeline(self) -> None:
        """Resolve local model weights and build the HF pipeline into `self._pipeline`."""
        # Lazy import: transformers is heavy and only needed once the pipeline loads.
        from transformers import pipeline

        self._local_model_path = load_local_model_weights(
            self.model_name_or_path, MODEL_PATHS["hallucination_scorer"]
        )

        self._pipeline = pipeline(
            task=self.task,
            model=self._local_model_path,
            device=self.device,
            # NOTE(review): HHEM's pair-classification task appears to require
            # custom code shipped with the model repo; this executes code from
            # that repo at load time.
            trust_remote_code=True,
        )

    def _predict(self, query: str, context: str | list[str], output: str) -> float:
        """Run HHEM on (query + context, output) and return a hallucination score.

        Higher scores mean more hallucination. Inputs exceeding the model's
        token budget are truncated before inference.
        """
        # NOTE(review): `assert` is stripped under `python -O`; a scorer invoked
        # without `load_pipeline` would then fail later with an AttributeError.
        assert self._pipeline is not None, (
            "Pipeline not loaded, check your `model_name_or_path`"
        )
        tokenizer = self._pipeline.tokenizer
        # A list of context passages is joined into one premise string.
        context_str = "\n\n".join(context) if isinstance(context, list) else context
        inps = query + "\n\n" + context_str
        outs = output

        # Tokenize without truncation first, purely to measure lengths.
        inps_toks = tokenizer(inps, truncation=False)
        outs_toks = tokenizer(outs, truncation=False)

        len_inps = len(inps_toks.input_ids)
        len_outs = len(outs_toks.input_ids)

        # Handle large inputs
        if len_inps + len_outs > self._model_max_length:
            logger.info(
                f"sum of query, key and output tokens ({len_inps + len_outs}) > model_max_length ({self._model_max_length}), curtailing input query and context.."
            )
            # If the output is less than 1000 tokens, curtail the input query and context only
            if len_outs < self._model_max_length - 1000:
                # 975 (vs 1000) leaves a small margin of headroom, presumably for
                # special tokens -- TODO confirm intent of the 975/1025 constants.
                inp_remaining = self._model_max_length - (len_outs + 975)
                inps_input_ids = inps_toks.input_ids[:inp_remaining]
                out_input_ids = outs_toks.input_ids
            else:
                # If the output is greater than 1000 tokens, curtail all 3,  query, context and output
                inps_input_ids = inps_toks.input_ids[:975]
                out_input_ids = outs_toks.input_ids[: self._model_max_length - 1025]

            # Decode the truncated token ids back to text for the pipeline call.
            inps = tokenizer.decode(inps_input_ids)
            outs = tokenizer.decode(out_input_ids)

        # NOTE(review): assumes the pair-classification pipeline returns a bare
        # float consistency score (1.0 = fully supported) -- verify against the
        # HHEM model card.
        pred = self._pipeline((inps, outs))
        # Invert score so that higher means more hallucination
        score = 1 - pred
        return score

    @validate_call
    @weave.op
    def score(
        self,
        *,
        query: str,
        context: str | list[str],
        output: str,
        **kwargs: Any,
    ) -> WeaveScorerResult:
        """Score the hallucination of the query and context.

        Args:
            query: str, The query to score, must be a string
            context: Union[str, list[str]], The context to score, must be a string or list of strings
            output: str, The output string to score for hallucination given the query and context, must be a string
        """
        score = self._predict(query, context, output)
        # Strict inequality: a score exactly at the threshold does NOT pass.
        passed = score < self.threshold
        return WeaveScorerResult(
            passed=passed,
            metadata={"score": score},
        )
